{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tuning a BERT Model and Create a Text Classifier\n",
    "\n",
    "In the previous section, we've already performed the Feature Engineering to create BERT embeddings from the `reviews_body` text using the pre-trained BERT model, and split the dataset into train, validation and test files. To optimize for Tensorflow training, we saved the files in TFRecord format. \n",
    "\n",
    "Now, let’s fine-tune the BERT model to our Customer Reviews Dataset and add a new classification layer to predict the `star_rating` for a given `review_body`.\n",
    "\n",
    "![BERT Training](img/bert_training.png)\n",
    "\n",
    "As mentioned earlier, BERT’s attention mechanism is called a Transformer. This is, not coincidentally, the name of the popular BERT Python library, “Transformers,” maintained by a company called HuggingFace. \n",
    "\n",
    "We will use a variant of BERT called [**DistilBert**](https://arxiv.org/pdf/1910.01108.pdf) which requires less memory and compute, but maintains very good accuracy on our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import random\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "import argparse\n",
    "import json\n",
    "import subprocess\n",
    "import sys\n",
    "import os\n",
    "import tensorflow as tf\n",
    "from transformers import DistilBertTokenizer\n",
    "from transformers import TFDistilBertForSequenceClassification\n",
    "from transformers import DistilBertConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r max_seq_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    max_seq_length\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(max_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_data_and_label_from_record(record):\n",
    "    x = {\n",
    "        \"input_ids\": record[\"input_ids\"],\n",
    "        \"input_mask\": record[\"input_mask\"],\n",
    "        #        'segment_ids': record['segment_ids']\n",
    "    }\n",
    "    y = record[\"label_ids\"]\n",
    "\n",
    "    return (x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def file_based_input_dataset_builder(channel, input_filenames, pipe_mode, is_training, drop_remainder):\n",
    "\n",
    "    # For training, we want a lot of parallel reading and shuffling.\n",
    "    # For eval, we want no shuffling and parallel reading doesn't matter.\n",
    "\n",
    "    if pipe_mode:\n",
    "        print(\"***** Using pipe_mode with channel {}\".format(channel))\n",
    "        from sagemaker_tensorflow import PipeModeDataset\n",
    "\n",
    "        dataset = PipeModeDataset(channel=channel, record_format=\"TFRecord\")\n",
    "    else:\n",
    "        print(\"***** Using input_filenames {}\".format(input_filenames))\n",
    "        dataset = tf.data.TFRecordDataset(input_filenames)\n",
    "\n",
    "    dataset = dataset.repeat(100)\n",
    "    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n",
    "    name_to_features = {\n",
    "        \"input_ids\": tf.io.FixedLenFeature([max_seq_length], tf.int64),\n",
    "        \"input_mask\": tf.io.FixedLenFeature([max_seq_length], tf.int64),\n",
    "        #      \"segment_ids\": tf.io.FixedLenFeature([max_seq_length], tf.int64),\n",
    "        \"label_ids\": tf.io.FixedLenFeature([], tf.int64),\n",
    "    }\n",
    "\n",
    "    def _decode_record(record, name_to_features):\n",
    "        \"\"\"Decodes a record to a TensorFlow example.\"\"\"\n",
    "        return tf.io.parse_single_example(record, name_to_features)\n",
    "\n",
    "    dataset = dataset.apply(\n",
    "        tf.data.experimental.map_and_batch(\n",
    "            lambda record: _decode_record(record, name_to_features),\n",
    "            batch_size=8,\n",
    "            drop_remainder=drop_remainder,\n",
    "            num_parallel_calls=tf.data.experimental.AUTOTUNE,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    dataset.cache()\n",
    "\n",
    "    if is_training:\n",
    "        dataset = dataset.shuffle(seed=42, buffer_size=10, reshuffle_each_iteration=True)\n",
    "\n",
    "    return dataset"
   ]
  },
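  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parsing step inside `file_based_input_dataset_builder` can be tried in isolation. The cell below is a minimal sketch, not part of the original pipeline: it writes two made-up records to a temporary TFRecord file and reads them back with the same `FixedLenFeature` schema. The sequence length of 8, the token ids, and every `demo_` name are assumptions chosen for illustration. Each parsed batch is a dictionary keyed by feature name, which is the structure that `select_data_and_label_from_record` splits into inputs and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "demo_seq_length = 8  # hypothetical; the notebook itself uses max_seq_length\n",
    "\n",
    "\n",
    "def demo_int64_feature(values):\n",
    "    # Wrap a list of ints in the protobuf type that TFRecord files store.\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))\n",
    "\n",
    "\n",
    "def demo_serialize(input_ids, input_mask, label_id):\n",
    "    # Use the same feature names that the dataset builder expects.\n",
    "    features = tf.train.Features(\n",
    "        feature={\n",
    "            \"input_ids\": demo_int64_feature(input_ids),\n",
    "            \"input_mask\": demo_int64_feature(input_mask),\n",
    "            \"label_ids\": demo_int64_feature([label_id]),\n",
    "        }\n",
    "    )\n",
    "    return tf.train.Example(features=features).SerializeToString()\n",
    "\n",
    "\n",
    "# Write two made-up, already padded records to a temporary file.\n",
    "demo_path = os.path.join(tempfile.mkdtemp(), \"demo.tfrecord\")\n",
    "with tf.io.TFRecordWriter(demo_path) as demo_writer:\n",
    "    demo_writer.write(demo_serialize([101, 2023, 102, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0], 4))\n",
    "    demo_writer.write(demo_serialize([101, 2919, 102, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0], 0))\n",
    "\n",
    "demo_schema = {\n",
    "    \"input_ids\": tf.io.FixedLenFeature([demo_seq_length], tf.int64),\n",
    "    \"input_mask\": tf.io.FixedLenFeature([demo_seq_length], tf.int64),\n",
    "    \"label_ids\": tf.io.FixedLenFeature([], tf.int64),\n",
    "}\n",
    "\n",
    "# Read the file back, decode each record, and group the records into one batch.\n",
    "demo_dataset = (\n",
    "    tf.data.TFRecordDataset([demo_path])\n",
    "    .map(lambda record: tf.io.parse_single_example(record, demo_schema))\n",
    "    .batch(2)\n",
    ")\n",
    "\n",
    "for demo_batch in demo_dataset:\n",
    "    print(demo_batch[\"input_ids\"].shape, demo_batch[\"label_ids\"].numpy())"
   ]
  },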
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = \"./data-tfrecord/bert-train\"\n",
    "train_data_filenames = glob(\"{}/*.tfrecord\".format(train_data))\n",
    "print(\"train_data_filenames {}\".format(train_data_filenames))\n",
    "\n",
    "train_dataset = file_based_input_dataset_builder(\n",
    "    channel=\"train\", input_filenames=train_data_filenames, pipe_mode=False, is_training=True, drop_remainder=False\n",
    ").map(select_data_and_label_from_record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_data = \"./data-tfrecord/bert-validation\"\n",
    "validation_data_filenames = glob(\"{}/*.tfrecord\".format(validation_data))\n",
    "print(\"validation_data_filenames {}\".format(validation_data_filenames))\n",
    "\n",
    "validation_dataset = file_based_input_dataset_builder(\n",
    "    channel=\"validation\",\n",
    "    input_filenames=validation_data_filenames,\n",
    "    pipe_mode=False,\n",
    "    is_training=False,\n",
    "    drop_remainder=False,\n",
    ").map(select_data_and_label_from_record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = \"./data-tfrecord/bert-test\"\n",
    "test_data_filenames = glob(\"{}/*.tfrecord\".format(test_data))\n",
    "print(test_data_filenames)\n",
    "\n",
    "test_dataset = file_based_input_dataset_builder(\n",
    "    channel=\"test\", input_filenames=test_data_filenames, pipe_mode=False, is_training=False, drop_remainder=False\n",
    ").map(select_data_and_label_from_record)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify Manual Hyper-Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 1\n",
    "steps_per_epoch = 10\n",
    "validation_steps = 10\n",
    "test_steps = 10\n",
    "freeze_bert_layer = True\n",
    "learning_rate = 3e-5\n",
    "epsilon = 1e-08"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Pretrained BERT Model \n",
    "https://huggingface.co/transformers/pretrained_models.html "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "CLASSES = [1, 2, 3, 4, 5]\n",
    "\n",
    "config = DistilBertConfig.from_pretrained(\n",
    "    \"distilbert-base-uncased\",\n",
    "    num_labels=len(CLASSES),\n",
    "    id2label={0: 1, 1: 2, 2: 3, 3: 4, 4: 5},\n",
    "    label2id={1: 0, 2: 1, 3: 2, 4: 3, 5: 4},\n",
    ")\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_model = TFDistilBertForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", config=config)\n",
    "\n",
    "input_ids = tf.keras.layers.Input(shape=(max_seq_length,), name=\"input_ids\", dtype=\"int32\")\n",
    "input_mask = tf.keras.layers.Input(shape=(max_seq_length,), name=\"input_mask\", dtype=\"int32\")\n",
    "\n",
    "embedding_layer = transformer_model.distilbert(input_ids, attention_mask=input_mask)[0]\n",
    "X = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(\n",
    "    embedding_layer\n",
    ")\n",
    "X = tf.keras.layers.GlobalMaxPool1D()(X)\n",
    "X = tf.keras.layers.Dense(50, activation=\"relu\")(X)\n",
    "X = tf.keras.layers.Dropout(0.2)(X)\n",
    "X = tf.keras.layers.Dense(len(CLASSES), activation=\"softmax\")(X)\n",
    "\n",
    "model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=X)\n",
    "\n",
    "for layer in model.layers[:3]:\n",
    "    layer.trainable = not freeze_bert_layer"
   ]
  },
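  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Freezing works by setting `trainable = False` on a layer, which removes its weights from the set that the optimizer updates. The cell below is a small sketch of that effect on a stand-in model. The layer sizes and all `demo_` names are hypothetical, and no pre-trained weights are loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Stand-in for a pre-trained encoder followed by a new classification head.\n",
    "demo_inputs = tf.keras.layers.Input(shape=(16,), name=\"demo_inputs\")\n",
    "demo_encoder = tf.keras.layers.Dense(32, activation=\"relu\", name=\"demo_encoder\")\n",
    "demo_head = tf.keras.layers.Dense(5, activation=\"softmax\", name=\"demo_head\")\n",
    "demo_model = tf.keras.Model(inputs=demo_inputs, outputs=demo_head(demo_encoder(demo_inputs)))\n",
    "\n",
    "print(\"Trainable weight tensors before freezing:\", len(demo_model.trainable_weights))\n",
    "\n",
    "# Freeze the encoder so that only the head would be updated during training.\n",
    "demo_encoder.trainable = False\n",
    "\n",
    "print(\"Trainable weight tensors after freezing:\", len(demo_model.trainable_weights))"
   ]
  },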
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup the Custom Classifier Model Here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "\n",
    "metric = tf.keras.metrics.SparseCategoricalAccuracy(\"accuracy\")\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=epsilon)\n",
    "\n",
    "model.compile(optimizer=optimizer, loss=loss, metrics=[metric])\n",
    "\n",
    "model.summary()"
   ]
  },
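  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `from_logits` argument tells the loss whether the model outputs raw scores or probabilities. The classifier above ends in a softmax layer, so it outputs probabilities. The sketch below uses made-up scores for a single example to show that the two settings agree only when each one receives the kind of input it expects. The numbers and all `demo_` names are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "demo_labels = tf.constant([2])\n",
    "demo_logits = tf.constant([[1.0, 2.0, 4.0, 0.5, 0.1]])  # made-up raw scores\n",
    "demo_probs = tf.nn.softmax(demo_logits)\n",
    "\n",
    "demo_loss_on_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "demo_loss_on_probs = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)\n",
    "\n",
    "# Matching pairs: both compute approximately the same cross-entropy.\n",
    "print(\"logits, from_logits=True :\", float(demo_loss_on_logits(demo_labels, demo_logits)))\n",
    "print(\"probs,  from_logits=False:\", float(demo_loss_on_probs(demo_labels, demo_probs)))\n",
    "\n",
    "# Mismatched pair: softmax is applied a second time, which distorts the loss.\n",
    "print(\"probs,  from_logits=True :\", float(demo_loss_on_logits(demo_labels, demo_probs)))"
   ]
  },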
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = []\n",
    "\n",
    "log_dir = \"./tmp/tensorboard/\"\n",
    "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
    "callbacks.append(tensorboard_callback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "history = model.fit(\n",
    "    train_dataset,\n",
    "    shuffle=True,\n",
    "    epochs=epochs,\n",
    "    steps_per_epoch=steps_per_epoch,\n",
    "    validation_data=validation_dataset,\n",
    "    validation_steps=validation_steps,\n",
    "    callbacks=callbacks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Trained model {}\".format(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate on Holdout Test Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_history = model.evaluate(test_dataset, steps=test_steps, callbacks=callbacks)\n",
    "print(test_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorflow_model_dir = \"./tmp/tensorflow/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p $tensorflow_model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(tensorflow_model_dir, include_optimizer=False, overwrite=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al $tensorflow_model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!saved_model_cli show --all --dir $tensorflow_model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !saved_model_cli run --dir $tensorflow_model_dir --tag_set serve --signature_def serving_default \\\n",
    "#     --input_exprs 'input_ids=np.zeros((1,64));input_mask=np.zeros((1,64))'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict with Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from transformers import DistilBertTokenizer\n",
    "\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "\n",
    "sample_review_body = \"This product is terrible.\"\n",
    "\n",
    "encode_plus_tokens = tokenizer.encode_plus(\n",
    "    sample_review_body, padding='max_length', max_length=max_seq_length, truncation=True, return_tensors=\"tf\"\n",
    ")\n",
    "\n",
    "# The id from the pre-trained BERT vocabulary that represents the token.  (Padding of 0 will be used if the # of tokens is less than `max_seq_length`)\n",
    "input_ids = encode_plus_tokens[\"input_ids\"]\n",
    "\n",
    "# Specifies which tokens BERT should pay attention to (0 or 1).  Padded `input_ids` will have 0 in each of these vector elements.\n",
    "input_mask = encode_plus_tokens[\"attention_mask\"]\n",
    "\n",
    "outputs = model.predict(x=(input_ids, input_mask))\n",
    "\n",
    "prediction = [{\"label\": config.id2label[item.argmax()], \"score\": item.max().item()} for item in outputs]\n",
    "\n",
    "print(\"\")\n",
    "print('Predicted star_rating \"{}\" for review_body \"{}\"'.format(prediction[0][\"label\"], sample_review_body))"
   ]
  },
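  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last step of the prediction maps each row of class probabilities to a star rating. The sketch below repeats that lookup on a made-up probability vector so that it can be followed without a trained model. The values and all `demo_` names are assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_id2label = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}  # same mapping as the model config\n",
    "demo_outputs = np.array([[0.70, 0.15, 0.08, 0.04, 0.03]])  # made-up probabilities\n",
    "\n",
    "# argmax picks the most likely class index; max gives its probability.\n",
    "demo_prediction = [\n",
    "    {\"label\": demo_id2label[int(row.argmax())], \"score\": float(row.max())} for row in demo_outputs\n",
    "]\n",
    "print(demo_prediction)"
   ]
  },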
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Release Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "\n",
    "try {\n",
    "    Jupyter.notebook.save_checkpoint();\n",
    "    Jupyter.notebook.session.delete();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
